package com.mlamp.广度优先遍历算法;

import javax.swing.tree.TreeNode;
import java.util.*;

/**
 * Counts, for every node of a labelled tree, how many nodes in that node's
 * subtree (the node itself included) carry the same label as the node.
 *
 * <p>The tree has {@code n} nodes numbered {@code 0..n-1}, is rooted at node
 * {@code 0}, and is described by {@code n - 1} undirected edges. The label of
 * node {@code i} is {@code labels.charAt(i)}.
 */
public class 子树中标签相同的节点数 {

    /** Small demo: prints the answer for a seven-node sample tree. */
    public static void main(String[] args) {
        子树中标签相同的节点数 instance = new 子树中标签相同的节点数();

        int[][] edges = new int[][]{
                new int[]{0, 1},
                new int[]{0, 2},
                new int[]{1, 4},
                new int[]{1, 5},
                new int[]{2, 3},
                new int[]{2, 6}
        };
        String labels = "abaedcd";
        System.out.println(Arrays.toString(instance.countSubTrees(7, edges, labels)));
    }

    /**
     * Computes the per-node count of same-labelled nodes in each subtree.
     *
     * <p>A single depth-first traversal from node 0 keeps one running counter
     * per label. The counter for a node's own label is sampled when the node
     * is entered and again when it is left; the difference is exactly the
     * number of nodes with that label visited in between, i.e. inside the
     * node's subtree. The traversal is iterative, so very deep trees do not
     * overflow the call stack.
     *
     * @param n      number of nodes in the tree
     * @param edges  {@code n - 1} undirected edges; the first two entries of
     *               each edge are its endpoints
     * @param labels label of node {@code i} at index {@code i}; must have
     *               length {@code n}
     * @return an array of length {@code n} whose element {@code i} is the
     *         number of nodes in the subtree of {@code i} labelled like
     *         {@code i}; an empty array when the input is null, empty,
     *         inconsistent in size, contains a malformed or out-of-range
     *         edge, or does not describe a tree
     */
    public int[] countSubTrees(int n, int[][] edges, String labels) {
        if (labels == null || edges == null || labels.isEmpty()) {
            return new int[0];
        }
        if (labels.length() != n || edges.length != n - 1) {
            return new int[0];
        }
        List<List<Integer>> adjacency = buildAdjacency(n, edges);
        if (adjacency == null) {
            return new int[0];
        }

        // One running counter per label value that actually occurs.
        int maxLabel = 0;
        for (int i = 0; i < n; i++) {
            maxLabel = Math.max(maxLabel, labels.charAt(i));
        }
        int[] seen = new int[maxLabel + 1];

        int[] result = new int[n];
        int[] countOnEntry = new int[n]; // seen[label] sampled when the node was entered
        int[] cursor = new int[n];       // next neighbour index to examine per node
        int[] parent = new int[n];
        boolean[] visited = new boolean[n];
        int[] path = new int[n];         // explicit DFS stack; each node is pushed at most once
        int depth = 0;

        parent[0] = -1;
        visited[0] = true;
        countOnEntry[0] = seen[labels.charAt(0)]++;
        path[depth++] = 0;
        int visitedCount = 1;

        while (depth > 0) {
            int node = path[depth - 1];
            List<Integer> neighbours = adjacency.get(node);
            if (cursor[node] < neighbours.size()) {
                int next = neighbours.get(cursor[node]++);
                if (next == parent[node]) {
                    continue; // the undirected edge back to the parent
                }
                if (visited[next]) {
                    return new int[0]; // cycle, duplicate edge or self-loop: not a tree
                }
                visited[next] = true;
                parent[next] = node;
                countOnEntry[next] = seen[labels.charAt(next)]++;
                path[depth++] = next;
                visitedCount++;
            } else {
                // All descendants processed: everything counted since entry is in this subtree.
                depth--;
                result[node] = seen[labels.charAt(node)] - countOnEntry[node];
            }
        }

        if (visitedCount != n) {
            return new int[0]; // some node is unreachable from the root: not a tree
        }
        return result;
    }

    /**
     * Builds an undirected adjacency list for nodes {@code 0..n-1}.
     *
     * @return the adjacency list, or {@code null} if any edge is null, has
     *         fewer than two entries, or references a node outside
     *         {@code [0, n)}
     */
    private static List<List<Integer>> buildAdjacency(int n, int[][] edges) {
        List<List<Integer>> adjacency = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            adjacency.add(new ArrayList<Integer>());
        }
        for (int[] edge : edges) {
            if (edge == null || edge.length < 2) {
                return null;
            }
            int a = edge[0];
            int b = edge[1];
            if (a < 0 || a >= n || b < 0 || b >= n) {
                return null;
            }
            adjacency.get(a).add(b);
            adjacency.get(b).add(a);
        }
        return adjacency;
    }

}
